{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.utils.data as data\n",
    "\n",
    "from torch import Tensor\n",
    "from torch.distributions import Bernoulli, Distribution, Independent, Normal, Uniform\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms.functional import to_pil_image, to_tensor\n",
    "from tqdm import tqdm\n",
    "\n",
    "import zuko\n",
    "\n",
    "_ = torch.manual_seed(0)  # seed the global torch RNG so sampling and DataLoader shuffling are reproducible"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiwAAAGdCAYAAAAxCSikAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAoMElEQVR4nO3df1RU953/8ReQZfAXo5YyA3ZaQN24Nuo0IFO6JnFPpg45tqfupl305CzI6TFnTeKJ34k1kiYQ15wzxliXJlLpums0pq5su60527i0OXNK9vSUSAvryTaJHnV18UdnBPcwo+QEcmC+f6QZdxSUi+h8GJ6Pc+6Jc+dzP77vmcC8/NzP/dy0WCwWEwAAgMHSk10AAADAzRBYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGuyvZBYyFwcFBXbhwQdOmTVNaWlqyywEAACMQi8V0+fJl5efnKz39xmMoKRFYLly4IJfLlewyAADAKJw9e1af+9znbtgmJQLLtGnTJH1ywtnZ2UmuBgAAjEQ0GpXL5Yp/j99ISgSWTy8DZWdnE1gAABhnRjKdg0m3AADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMa7K9kFAEhdBZvevG7fma3Lk1AJgPFuVCMsDQ0NKigoUFZWljwej9ra2kZ03MGDB5WWlqYVK1Yk7I/FYqqtrVVeXp4mTZokr9erEydOjKY0AACQgiwHlqamJvn9ftXV1amjo0OLFi2Sz+fTxYsXb3jcmTNntGHDBt13333Xvbdt2za9/PLLamxs1JEjRzRlyhT5fD599NFHVssDAAApyHJg2bFjh9asWaPq6mrNnz9fjY2Nmjx5svbs2TPsMQMDA3rkkUe0efNmFRUVJbwXi8VUX1+vZ599Vt/4xje0cOFCvfbaa7pw4YIOHTpk+YQAAEDqsRRY+vv71d7eLq/Xe7WD9HR5vV61trYOe9zf/d3fKTc3V9/+9reve+/06dMKhUIJfdrtdnk8nmH77OvrUzQaTdgAAEDqshRYuru7NTAwIIfDkbDf4XAoFAoNecyvf/1r/dM//ZN279495PufHmelz0AgILvdHt9cLpeV0wAAAOPMbb1L6PLly/qbv/kb7d69Wzk5OWPWb01Njfx+f/x1NBoltAApjLuNAFgKLDk5OcrIyFA4HE7YHw6H5XQ6r2t/6tQpnTlzRl//+tfj+wYHBz/5i++6S8ePH48fFw6HlZeXl9Cn2+0esg6bzSabzWaldAAAMI5ZCiyZmZkqLi5WMBiM35o8ODioYDCoJ5544rr28+bN03/9138l7Hv22Wd1+fJlff/735fL5dKf/MmfyOl0KhgMxgNKNBrVkSNHtHbt2tGdFQBjXTtawkgJgJGwfEnI7/erqqpKJSUlKi0tVX19vXp7e1VdXS1Jqqys1KxZsxQIBJSVlaV77rkn4fjp06dLUsL+9evX64UXXtDcuXNVWFio5557Tvn5+det1wIAACYmy4GloqJCXV1dqq2tVSgUktvtVnNzc3zSbGdnp9LTrd0tvXHjRvX29urRRx9VT0+PlixZoubmZmVlZVktDwAApKC0WCwWS3YRtyoajcputysSiSg7OzvZ5QD4o6Emy47EtZeJmHQLpCYr3988/BAAABiPwAIAAIxHYAEAAMa7rQvHAcBojHbuC4DUxQgLAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4rMMCYMywfgqA24URFgAAYDwCCwAAMB6XhACMSyO5/HRm6/I7UAmAO4HAAmBEmJ8CIJm4JAQAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4/EsIQAp
a6jnH/FARGB8YoQFAAAYj8ACAACMR2ABAADGI7AAAADjjSqwNDQ0qKCgQFlZWfJ4PGpraxu27U9/+lOVlJRo+vTpmjJlitxut/bv35/QZvXq1UpLS0vYysvLR1MaAABIQZbvEmpqapLf71djY6M8Ho/q6+vl8/l0/Phx5ebmXtd+5syZ+u53v6t58+YpMzNTP//5z1VdXa3c3Fz5fL54u/Lycr366qvx1zabbZSnBAAAUk1aLBaLWTnA4/Fo8eLF2rlzpyRpcHBQLpdL69at06ZNm0bUx7333qvly5dry5Ytkj4ZYenp6dGhQ4esVf9H0WhUdrtdkUhE2dnZo+oDwFVD3Q6cKritGTCHle9vS5eE+vv71d7eLq/Xe7WD9HR5vV61trbe9PhYLKZgMKjjx4/r/vvvT3ivpaVFubm5uvvuu7V27VpdunRp2H76+voUjUYTNgAAkLosXRLq7u7WwMCAHA5Hwn6Hw6Fjx44Ne1wkEtGsWbPU19enjIwM/eAHP9BXv/rV+Pvl5eX6q7/6KxUWFurUqVN65pln9NBDD6m1tVUZGRnX9RcIBLR582YrpQMAgHHsjqx0O23aNB09elRXrlxRMBiU3+9XUVGRli5dKklauXJlvO2CBQu0cOFCzZ49Wy0tLXrwwQev66+mpkZ+vz/+OhqNyuVy3fbzAAAAyWEpsOTk5CgjI0PhcDhhfzgcltPpHPa49PR0zZkzR5Lkdrv1wQcfKBAIxAPLtYqKipSTk6OTJ08OGVhsNhuTcgEAmEAszWHJzMxUcXGxgsFgfN/g4KCCwaDKyspG3M/g4KD6+vqGff/cuXO6dOmS8vLyrJQHAABSlOVLQn6/X1VVVSopKVFpaanq6+vV29ur6upqSVJlZaVmzZqlQCAg6ZP5JiUlJZo9e7b6+vp0+PBh7d+/X7t27ZIkXblyRZs3b9bDDz8sp9OpU6dOaePGjZozZ07Cbc8AAGDishxYKioq1NXVpdraWoVCIbndbjU3N8cn4nZ2dio9/erATW9vrx577DGdO3dOkyZN0rx58/T666+roqJCkpSRkaF3331X+/btU09Pj/Lz87Vs2TJt2bKFyz4Axty1t2xzmzMwPlheh8VErMMCjK1UXoflWgQWIHmsfH/fkbuEAJhtIgUUAOMTDz8EAADGY4QFwIQ21OgSl4kA8zDCAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMN5dyS4AwJ1VsOnNZJcAAJYxwgIAAIxHYAEAAMbjkhAAXOPay2Znti5PUiUAPkVgAYCbGGreDyEGuLO4JAQAAIxHYAEAAMYbVWBpaGhQQUGBsrKy5PF41NbWNmzbn/70pyopKdH06dM1ZcoUud1u7d+/P6FNLBZTbW2t8vLyNGnSJHm9Xp04cWI0pQEAgBRkObA0NTXJ7/errq5OHR0dWrRokXw+ny5evDhk+5kzZ+q73/2uWltb9e6776q6ulrV1dX6xS9+EW+zbds2vfzyy2psbNSRI0c0ZcoU+Xw+ffTRR6M/MwAAkDLSYrFYzMoBHo9Hixcv1s6dOyVJg4ODcrlcWrdunTZt2jSiPu69914tX75cW7ZsUSwWU35+vp566ilt2LBBkhSJRORwOLR3716tXLnypv1Fo1HZ7XZFIhFlZ2dbOR1gwmHhuLHBpFvg1ln5/rY0wtLf36/29nZ5vd6rHaSny+v1qrW19abHx2IxBYNBHT9+XPfff78k6fTp0wqFQgl92u12eTyeYfvs6+tTNBpN2AAAQOqyFFi6u7s1MDAgh8ORsN/hcCgUCg17XCQS0dSpU5WZmanly5frlVde0Ve/+lVJih9npc9AICC73R7fXC6XldMAAADjzB1Z
h2XatGk6evSorly5omAwKL/fr6KiIi1dunRU/dXU1Mjv98dfR6NRQgswDC4BAUgFlgJLTk6OMjIyFA6HE/aHw2E5nc5hj0tPT9ecOXMkSW63Wx988IECgYCWLl0aPy4cDisvLy+hT7fbPWR/NptNNpvNSukAAGAcs3RJKDMzU8XFxQoGg/F9g4ODCgaDKisrG3E/g4OD6uvrkyQVFhbK6XQm9BmNRnXkyBFLfQIAgNRl+ZKQ3+9XVVWVSkpKVFpaqvr6evX29qq6ulqSVFlZqVmzZikQCEj6ZL5JSUmJZs+erb6+Ph0+fFj79+/Xrl27JElpaWlav369XnjhBc2dO1eFhYV67rnnlJ+frxUrVozdmQIAgHHLcmCpqKhQV1eXamtrFQqF5Ha71dzcHJ8029nZqfT0qwM3vb29euyxx3Tu3DlNmjRJ8+bN0+uvv66Kiop4m40bN6q3t1ePPvqoenp6tGTJEjU3NysrK2sMThEAAIx3ltdhMRHrsADDY9Lt7cE6LMCtu23rsAAAACQDgQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAY765kFwAA41HBpjcTXp/ZujxJlQATAyMsAADAeIywACnk2n/1A0CqYIQFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPlW6BcYyVbc0x1GfB84WAscMICwAAMN6oAktDQ4MKCgqUlZUlj8ejtra2Ydvu3r1b9913n2bMmKEZM2bI6/Ve13716tVKS0tL2MrLy0dTGgAASEGWA0tTU5P8fr/q6urU0dGhRYsWyefz6eLFi0O2b2lp0apVq/SrX/1Kra2tcrlcWrZsmc6fP5/Qrry8XH/4wx/i2z//8z+P7owAAEDKsRxYduzYoTVr1qi6ulrz589XY2OjJk+erD179gzZ/kc/+pEee+wxud1uzZs3T//4j/+owcFBBYPBhHY2m01OpzO+zZgxY3RnBAAAUo6lwNLf36/29nZ5vd6rHaSny+v1qrW1dUR9fPjhh/r44481c+bMhP0tLS3Kzc3V3XffrbVr1+rSpUvD9tHX16doNJqwAQCA1GUpsHR3d2tgYEAOhyNhv8PhUCgUGlEfTz/9tPLz8xNCT3l5uV577TUFg0G9+OKLevvtt/XQQw9pYGBgyD4CgYDsdnt8c7lcVk4DAACMM3f0tuatW7fq4MGDamlpUVZWVnz/ypUr439esGCBFi5cqNmzZ6ulpUUPPvjgdf3U1NTI7/fHX0ejUUILAAApzNIIS05OjjIyMhQOhxP2h8NhOZ3OGx67fft2bd26Vb/85S+1cOHCG7YtKipSTk6OTp48OeT7NptN2dnZCRsAAEhdlgJLZmamiouLEybMfjqBtqysbNjjtm3bpi1btqi5uVklJSU3/XvOnTunS5cuKS8vz0p5AAAgRVm+JOT3+1VVVaWSkhKVlpaqvr5evb29qq6uliRVVlZq1qxZCgQCkqQXX3xRtbW1OnDggAoKCuJzXaZOnaqpU6fqypUr2rx5sx5++GE5nU6dOnVKGzdu1Jw5c+Tz+cbwVAHgzrp29VtWvgVGz3JgqaioUFdXl2praxUKheR2u9Xc3ByfiNvZ2an09KsDN7t27VJ/f7+++c1vJvRTV1en559/XhkZGXr33Xe1b98+9fT0KD8/X8uWLdOWLVtks9lu8fQAAEAqSIvFYrFkF3GrotGo7Ha7IpEI81kwofAsofGFERYgkZXvb54lBAAAjEdgAQAAxruj67AAGD0u/wCYyBhhAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUA
ABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxrsr2QUAwERRsOnN6/ad2bo8CZUA4w8jLAAAwHgEFgAAYDwCCwAAMN6oAktDQ4MKCgqUlZUlj8ejtra2Ydvu3r1b9913n2bMmKEZM2bI6/Ve1z4Wi6m2tlZ5eXmaNGmSvF6vTpw4MZrSAABACrIcWJqamuT3+1VXV6eOjg4tWrRIPp9PFy9eHLJ9S0uLVq1apV/96ldqbW2Vy+XSsmXLdP78+Xibbdu26eWXX1ZjY6OOHDmiKVOmyOfz6aOPPhr9mQHjXMGmNxM2AJjI0mKxWMzKAR6PR4sXL9bOnTslSYODg3K5XFq3bp02bdp00+MHBgY0Y8YM7dy5U5WVlYrFYsrPz9dTTz2lDRs2SJIikYgcDof27t2rlStX3rTPaDQqu92uSCSi7OxsK6cDGIuQMjFwlxAmMivf35ZGWPr7+9Xe3i6v13u1g/R0eb1etba2jqiPDz/8UB9//LFmzpwpSTp9+rRCoVBCn3a7XR6PZ8R9AgCA1GZpHZbu7m4NDAzI4XAk7Hc4HDp27NiI+nj66aeVn58fDyihUCjex7V9fvretfr6+tTX1xd/HY1GR3wOAABg/Lmjdwlt3bpVBw8e1M9+9jNlZWWNup9AICC73R7fXC7XGFYJAABMYymw5OTkKCMjQ+FwOGF/OByW0+m84bHbt2/X1q1b9ctf/lILFy6M7//0OCt91tTUKBKJxLezZ89aOQ0AADDOWAosmZmZKi4uVjAYjO8bHBxUMBhUWVnZsMdt27ZNW7ZsUXNzs0pKShLeKywslNPpTOgzGo3qyJEjw/Zps9mUnZ2dsAEAgNRl+VlCfr9fVVVVKikpUWlpqerr69Xb26vq6mpJUmVlpWbNmqVAICBJevHFF1VbW6sDBw6ooKAgPi9l6tSpmjp1qtLS0rR+/Xq98MILmjt3rgoLC/Xcc88pPz9fK1asGLszBQAA45blwFJRUaGuri7V1tYqFArJ7Xarubk5Pmm2s7NT6elXB2527dql/v5+ffOb30zop66uTs8//7wkaePGjert7dWjjz6qnp4eLVmyRM3Nzbc0zwUYT7iFeeK69rPnNmdgaJbXYTER67BgvCOw4FMEFkwkt20dFgAAgGQgsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA41l+lhAA4PYZ6jENLNcPMMICAADGAQILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8VroFkmCo1UwBAMNjhAUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHgEFgAAYDwCCwAAMB4LxwG3GYvE4VZd+//Qma3Lk1QJkDyMsAAAAOMRWAAAgPEILAAAwHijCiwNDQ0qKChQVlaWPB6P2trahm373nvv6eGHH1ZBQYHS0tJUX19/XZvnn39eaWlpCdu8efNGUxoAAEhBlgNLU1OT/H6/6urq1NHRoUWLFsnn8+nixYtDtv/www9VVFSkrVu3yul0DtvvF7/4Rf3hD3+Ib7/+9a+tlgYAAFKU5cCyY8cOrVmzRtXV1Zo/f74aGxs1efJk7dmzZ8j2ixcv1ksvvaSVK1fKZrMN2+9dd90lp9MZ33JycqyWBgAAUpSlwNLf36/29nZ5vd6rHaSny+v1qrW19ZYKOXHihPLz81VUVKRHHnlEnZ2dw7bt6+tTNBpN2AAAQOqyFFi6u7s1MDAgh8ORsN/hcCgUCo26CI/Ho71796q5uVm7du3S6dOndd999+ny5ctDtg8EArLb7fHN5XKN+u8GAADmM+IuoYceekjf+ta3tHDhQvl8Ph0+fFg9PT36l3/5lyHb19TUKBKJxLezZ8/e4YoBAMCdZGml25ycHGVkZCgcDifsD4fDN5xQa9X0
6dP1p3/6pzp58uSQ79tsthvOhwEAAKnF0ghLZmamiouLFQwG4/sGBwcVDAZVVlY2ZkVduXJFp06dUl5e3pj1CQAAxi/LzxLy+/2qqqpSSUmJSktLVV9fr97eXlVXV0uSKisrNWvWLAUCAUmfTNR9//33438+f/68jh49qqlTp2rOnDmSpA0bNujrX/+6vvCFL+jChQuqq6tTRkaGVq1aNVbnCQAAxjHLgaWiokJdXV2qra1VKBSS2+1Wc3NzfCJuZ2en0tOvDtxcuHBBX/rSl+Kvt2/fru3bt+uBBx5QS0uLJOncuXNatWqVLl26pM9+9rNasmSJ3nnnHX32s5+9xdMDAACpIC0Wi8WSXcStikajstvtikQiys7OTnY5QAKe1oyxxtOakSqsfH8bcZcQAADAjVi+JATgxhhRAYCxxwgLAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4rMMCAOPMUGv9sPotUh0jLAAAwHgEFgAAYDwCCwAAMB6BBQAAGI/AAgAAjEdgAQAAxiOwAAAA47EOCwCkgGvXZmFdFqQaRlgAAIDxCCwAAMB4BBYAAGA8AgsAADAegQUAABiPu4SAWzDUU3MBAGOPERYAAGA8AgsAADAel4QAC7gEBADJwQgLAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjjSqwNDQ0qKCgQFlZWfJ4PGpraxu27XvvvaeHH35YBQUFSktLU319/S33CQC4sYJNb163AeOZ5cDS1NQkv9+vuro6dXR0aNGiRfL5fLp48eKQ7T/88EMVFRVp69atcjqdY9InAACYWCwHlh07dmjNmjWqrq7W/Pnz1djYqMmTJ2vPnj1Dtl+8eLFeeuklrVy5UjabbUz6BAAAE4ulwNLf36/29nZ5vd6rHaSny+v1qrW1dVQFjKbPvr4+RaPRhA0AAKQuS4Glu7tbAwMDcjgcCfsdDodCodCoChhNn4FAQHa7Pb65XK5R/d0AAGB8GJd3CdXU1CgSicS3s2fPJrskAABwG1l6llBOTo4yMjIUDocT9ofD4WEn1N6OPm0227DzYQAAQOqxNMKSmZmp4uJiBYPB+L7BwUEFg0GVlZWNqoDb0ScAAEgtlp/W7Pf7VVVVpZKSEpWWlqq+vl69vb2qrq6WJFVWVmrWrFkKBAKSPplU+/7778f/fP78eR09elRTp07VnDlzRtQnAACY2CwHloqKCnV1dam2tlahUEhut1vNzc3xSbOdnZ1KT786cHPhwgV96Utfir/evn27tm/frgceeEAtLS0j6hNIBhbaAgBzpMVisViyi7hV0WhUdrtdkUhE2dnZyS4HKYLAglR3ZuvyZJeACc7K9/e4vEsIAABMLAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxCCwAAMB4BBYAAGA8AgsAADCe5YcfAqmKZwcBgLkYYQEAAMZjhAUAJqihRhV5gjNMxQgLAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDxuK0ZABB37a3O3OYMUzDCAgAAjEdgAQAAxuOSECYknhsEAOMLIywAAMB4BBYAAGA8AgsAADAegQUAABiPwAIAAIxHYAEAAMYjsAAAAOMRWAAAgPEILAAAwHijCiwNDQ0qKChQVlaWPB6P2trabtj+xz/+sebNm6esrCwtWLBAhw8fTnh/9erVSktLS9jKy8tHUxoAYAwVbHrzug1IBstL8zc1Ncnv96uxsVEej0f19fXy+Xw6fvy4cnNzr2v/m9/8RqtWrVIgENDXvvY1HThwQCtWrFBHR4fuueeeeLvy8nK9+uqr8dc2m22UpwRcj1+yADC+WR5h2bFjh9asWaPq6mrNnz9fjY2Nmjx5svbs2TNk++9///sqLy/Xd77zHf3Zn/2ZtmzZonvvvVc7d+5MaGez2eR0OuPbjBkzRndGAAAg5VgKLP39/Wpvb5fX673aQXq6
vF6vWltbhzymtbU1ob0k+Xy+69q3tLQoNzdXd999t9auXatLly4NW0dfX5+i0WjCBgAAUpelwNLd3a2BgQE5HI6E/Q6HQ6FQaMhjQqHQTduXl5frtddeUzAY1Isvvqi3335bDz30kAYGBobsMxAIyG63xzeXy2XlNAAAwDhjeQ7L7bBy5cr4nxcsWKCFCxdq9uzZamlp0YMPPnhd+5qaGvn9/vjraDRKaEEc81UAIPVYGmHJyclRRkaGwuFwwv5wOCyn0znkMU6n01J7SSoqKlJOTo5Onjw55Ps2m03Z2dkJGwAASF2WAktmZqaKi4sVDAbj+wYHBxUMBlVWVjbkMWVlZQntJemtt94atr0knTt3TpcuXVJeXp6V8gAAQIqyfJeQ3+/X7t27tW/fPn3wwQdau3atent7VV1dLUmqrKxUTU1NvP2TTz6p5uZmfe9739OxY8f0/PPP63e/+52eeOIJSdKVK1f0ne98R++8847OnDmjYDCob3zjG5ozZ458Pt8YnSYAABjPLM9hqaioUFdXl2praxUKheR2u9Xc3ByfWNvZ2an09Ks56Ctf+YoOHDigZ599Vs8884zmzp2rQ4cOxddgycjI0Lvvvqt9+/app6dH+fn5WrZsmbZs2cJaLABgoGvniZ3ZujxJlWAiSYvFYrFkF3GrotGo7Ha7IpEI81nApFvgDiOwYLSsfH/zLCEAAGA8AgsAADCeEeuwALeCS0AAkPoILACAWzLUPxqY14KxxiUhAABgPAILAAAwHoEFAAAYj8ACAACMx6RbjCvcEQSMD6yGi7HGCAsAADAegQUAABiPS0IwGpeAAAASIywAAGAcYIQFAHDbsRoubhUjLAAAwHgEFgAAYDwCCwAAMB5zWGAM7ggCAAyHwAIASApWw4UVXBICAADGI7AAAADjEVgAAIDxmMOCpGGSLQBgpBhhAQAAxmOEBXcEoykAbmYkvye4k2jiIrAAAMYNnkk0cRFYcFswogIAGEvMYQEAAMZjhAUAMK6xYu7EQGDBLePyDwCTMM8lNRFYYBkBBQBwpzGHBQAAGG9UgaWhoUEFBQXKysqSx+NRW1vbDdv/+Mc/1rx585SVlaUFCxbo8OHDCe/HYjHV1tYqLy9PkyZNktfr1YkTJ0ZTGsZYwaY3r9sAYLzjd9v4Y/mSUFNTk/x+vxobG+XxeFRfXy+fz6fjx48rNzf3uva/+c1vtGrVKgUCAX3ta1/TgQMHtGLFCnV0dOiee+6RJG3btk0vv/yy9u3bp8LCQj333HPy+Xx6//33lZWVdetniSHxAwpgouD33fiXFovFYlYO8Hg8Wrx4sXbu3ClJGhwclMvl0rp167Rp06br2ldUVKi3t1c///nP4/u+/OUvy+12q7GxUbFYTPn5+Xrqqae0YcMGSVIkEpHD4dDevXu1cuXKm9YUjUZlt9sViUSUnZ1t5XQmFH5gAeDWMHl3bFn5/rY0wtLf36/29nbV1NTE96Wnp8vr9aq1tXXIY1pbW+X3+xP2+Xw+HTp0SJJ0+vRphUIheb3e+Pt2u10ej0etra1DBpa+vj719fXFX0ciEUmfnPh4d0/dL0Z13O83+8akHwDA8D7//3580zbX/j7G8D793h7J2ImlwNLd3a2BgQE5HI6E/Q6HQ8eOHRvymFAoNGT7UCgUf//TfcO1uVYgENDmzZuv2+9yuUZ2IinIXp/sCgAAEr+PR+Py5cuy2+03bDMub2uuqalJGLUZHBzU//7v/+ozn/mM0tLS7mgt0WhULpdLZ8+e5XKUwficzMdnZD4+I/ONt88oFovp8uXLys/Pv2lbS4ElJydHGRkZCofDCfvD4bCcTueQxzidzhu2//S/4XBYeXl5CW3cbveQfdpsNtlstoR906dPt3IqYy47O3tc/M8x0fE5mY/PyHx8RuYbT5/RzUZWPmXptubMzEwVFxcrGAzG9w0ODioYDKqsrGzIY8rKyhLaS9Jbb70Vb19YWCin05nQJhqN6siRI8P2
CQAAJhbLl4T8fr+qqqpUUlKi0tJS1dfXq7e3V9XV1ZKkyspKzZo1S4FAQJL05JNP6oEHHtD3vvc9LV++XAcPHtTvfvc7/cM//IMkKS0tTevXr9cLL7yguXPnxm9rzs/P14oVK8buTAEAwLhlObBUVFSoq6tLtbW1CoVCcrvdam5ujk+a7ezsVHr61YGbr3zlKzpw4ICeffZZPfPMM5o7d64OHToUX4NFkjZu3Kje3l49+uij6unp0ZIlS9Tc3Dwu1mCx2Wyqq6u77hIVzMLnZD4+I/PxGZkvlT8jy+uwAAAA3Gk8SwgAABiPwAIAAIxHYAEAAMYjsAAAAOMRWG6Tvr4+ud1upaWl6ejRo8kuB3905swZffvb31ZhYaEmTZqk2bNnq66uTv39/ckubUJraGhQQUGBsrKy5PF41NbWluyS8EeBQECLFy/WtGnTlJubqxUrVuj48ePJLgs3sHXr1viSIamEwHKbbNy4cURLDePOOnbsmAYHB/XDH/5Q7733nv7+7/9ejY2NeuaZZ5Jd2oTV1NQkv9+vuro6dXR0aNGiRfL5fLp48WKyS4Okt99+W48//rjeeecdvfXWW/r444+1bNky9fb2Jrs0DOG3v/2tfvjDH2rhwoXJLmXMcVvzbfDv//7v8vv9+td//Vd98Ytf1H/+538O+5gBJN9LL72kXbt26b//+7+TXcqE5PF4tHjxYu3cuVPSJ6tnu1wurVu3Tps2bUpydbhWV1eXcnNz9fbbb+v+++9Pdjn4P65cuaJ7771XP/jBD/TCCy/I7Xarvr4+2WWNGUZYxlg4HNaaNWu0f/9+TZ48OdnlYAQikYhmzpyZ7DImpP7+frW3t8vr9cb3paeny+v1qrW1NYmVYTiRSESS+Jkx0OOPP67ly5cn/DylknH5tGZTxWIxrV69Wn/7t3+rkpISnTlzJtkl4SZOnjypV155Rdu3b092KRNSd3e3BgYG4itlf8rhcOjYsWNJqgrDGRwc1Pr16/Xnf/7nCauVI/kOHjyojo4O/fa3v012KbcNIywjsGnTJqWlpd1wO3bsmF555RVdvnxZNTU1yS55whnpZ/R/nT9/XuXl5frWt76lNWvWJKlyYPx4/PHH9fvf/14HDx5Mdin4P86ePasnn3xSP/rRj8bFI21GizksI9DV1aVLly7dsE1RUZH++q//Wv/2b/+mtLS0+P6BgQFlZGTokUce0b59+253qRPWSD+jzMxMSdKFCxe0dOlSffnLX9bevXsTnn+FO6e/v1+TJ0/WT37yk4SHnVZVVamnp0dvvPFG8opDgieeeEJvvPGG/uM//kOFhYXJLgf/x6FDh/SXf/mXysjIiO8bGBhQWlqa0tPT1dfXl/DeeEVgGUOdnZ2KRqPx1xcuXJDP59NPfvITeTwefe5zn0tidfjU+fPn9Rd/8RcqLi7W66+/nhI/yOOZx+NRaWmpXnnlFUmfXHb4/Oc/ryeeeIJJtwaIxWJat26dfvazn6mlpUVz585Ndkm4xuXLl/U///M/Cfuqq6s1b948Pf300ylz+Y45LGPo85//fMLrqVOnSpJmz55NWDHE+fPntXTpUn3hC1/Q9u3b1dXVFX/P6XQmsbKJy+/3q6qqSiUlJSotLVV9fb16e3tVXV2d7NKgTy4DHThwQG+88YamTZumUCgkSbLb7Zo0aVKSq4MkTZs27bpQMmXKFH3mM59JmbAiEVgwwbz11ls6efKkTp48eV2IZLAxOSoqKtTV1aXa2lqFQiG53W41NzdfNxEXybFr1y5J0tKlSxP2v/rqq1q9evWdLwgTFpeEAACA8ZhpCAAAjEdgAQAAxiOwAAAA4xFYAACA8QgsAADAeAQWAABgPAILAAAwHoEFAAAYj8ACAACMR2ABAADGI7AAAADjEVgAAIDx/j9DcgVjGdibUgAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Draw 100000 samples from a 1D standard Gaussian N(0, 1).\n",
    "# mu and sigma have shape (1,), so the samples come out with shape (100000, 1).\n",
    "mu = torch.tensor([0.0])\n",
    "sigma = torch.tensor([1.0])\n",
    "normal = Normal(mu, sigma)\n",
    "normal_samples = normal.sample((100000,))\n",
    "\n",
    "# Histogram of the samples with the analytic density overlaid as a sanity check.\n",
    "grid = torch.linspace(-5.0, 5.0, 500).unsqueeze(1)\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(normal_samples.squeeze(1).numpy(), bins=100, density=True, label=\"samples\")\n",
    "ax.plot(grid.squeeze(1).numpy(), normal.log_prob(grid).exp().squeeze(1).numpy(), label=\"N(0, 1) pdf\")\n",
    "ax.set(title=\"Samples from N(0, 1)\", xlabel=\"x\", ylabel=\"density\")\n",
    "ax.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAkD0lEQVR4nO3de3CU1f3H8U8SyAaEDdBIIhiJgAWRS4CYGKaKrVuDTb2MdhrQkZhabOu1XbUmXhKB2kShmBlNxTKAHVsL6nibitGakWnVKBpAEZARCoLaXYjoBoIGyZ7fH/5YXXNhn2STk13er5mdYc+e59nz3bO7z4ezz24SjDFGAAAAliTaHgAAADi+EUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWNXP9gAiEQwG9cknn2jw4MFKSEiwPRwAABABY4wOHDigESNGKDGx4/WPmAgjn3zyiTIzM20PAwAAdMGePXt08sknd3h7TISRwYMHS/q6GLfbbXk0AAAgEk1NTcrMzAwdxzsSE2Hk6EczbrebMAIAQIw51ikWnMAKAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACr+tkeAIDIZZU+H3Z9V1WhpZEAQPSwMgIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKn5nBN323d++kPj9CwBA5FgZAQAAVhFGAACAVYQRAABgFeeMoFOcDxL7mEP0NJ5j6C5WRgAAgFWEEQAAYBUf0wCd+O7yM0vPABB9rIwAAACrWBkBEDdYyQJiEysjAADAKlZG4hRftYMTPF8A2MTKCAAAsIqVkRjE/2KByPBaQV/FczMcYQQA0OM4+KIzhBEAfQ4HLuD4QhgB0C6+Jou+6nh5bnY1lMfi40MYAeJMe29gvXVfsfCmh29E62DXF/HcjC2EkeOIzRcny+7oLg4uQPwijOC4xcENAPqGLoWRmpoaLVq0SD6fT1OmTNEDDzyg3Nzcdvs+8sgjKikpCWtzuVz68ssvu3LXfZbt//l3ZdnU9pgjEUldfW3MsSAWltn7klh4rURLXwvpsfhcjeePv3qK4zCyevVqeb1eLV26VHl5eaqurlZBQYG2bdum4cOHt7uN2+3Wtm3bQtcTEhK6PmIgDhxPB7e+jrlAZ/paOItXjsPIkiVLNG/evNBqx9KlS/X8889rxYoVKi0tbXebhIQEZWRkdG+kcSBe0nJfGw8QK2wHn7722uVAj6MchZHDhw+roaFBZWVlobbExER5PB7V19d3uN3Bgwc1atQoBYNBTZs2TX/84x91xhlndNi/paVFLS0toetNTU1OhtltvEDgBB8lxZZ4PSDzvmVPX3tOxSJHYaSxsVGtra1KT08Pa09PT9f777/f7jbjxo3TihUrNHnyZAUCAS1evFgzZszQ5s2bdfLJJ7e7TWVlpebPn+9kaMBxKRbeBDlIoi+wvSqFzvX4t2ny8/OVn58fuj5jxgydfvrpevjhh7Vw4cJ2tykrK5PX6w1db2pqUmZmZk8PFVFk87cupPh4k4mFoGETj0/fxvyE4/HonKMwkpaWpqSkJPn9/rB2v98f8Tkh/fv319SpU7V9+/YO+7hcLrlcLidDi1n8rxHHm3gNjwC6LtFJ5+TkZE2fPl11dXWhtmAwqLq6urDVj860trZq06ZNOumkk5yNFAAAxCXHH9N4vV4VFxcrJydHubm5qq6uVnNzc+jbNXPnztXIkSNVWVkpSVqwYIHOOussjR07Vp9//rkWLVqkDz/8UL/85S+jWwliDqtCAKKBj0Bin+MwUlRUpH379qm8vFw+n0/Z2dmqra0NndS6e/duJSZ+s+Dy2Wefad68efL5fBo6dKimT5+u119/XRMmTIheFYgLvKEAwPGpSyewXn/9
9br++uvbvW3t2rVh1++//37df//9XbmbPiNefh/keMKqC7ojXn/RGOir+Ns0wP+L5ADU1YAZD8G0J2uIh8cnWo7nUMPzwLl4ecwII31MvDyxAACIFGEEAPo4PnZEvCOMAECMsf2jgkC0EUYAAMclVpz6DsIIAHQBKwZA9BBGcFzgwAHgeBULf1nc0c/BAwAARBsrIwAQh1gNjE/xOq+EkS6K1ycEAByvevKHD9E5wgjQR8TDm1w81ACg93HOCAAAsIowAgAArOJjGoRhmR0A0NtYGQEAAFYd9ysjrAQA6Cm8vwCRYWUEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABY1aUwUlNTo6ysLKWkpCgvL0/r1q2LaLtVq1YpISFBl1xySVfuFgAAxCHHYWT16tXyer2qqKjQ+vXrNWXKFBUUFGjv3r2dbrdr1y7dcsstOvvss7s8WAAAEH8ch5ElS5Zo3rx5Kikp0YQJE7R06VINHDhQK1as6HCb1tZWXXHFFZo/f75Gjx7drQEDAID44iiMHD58WA0NDfJ4PN/sIDFRHo9H9fX1HW63YMECDR8+XFdffXVE99PS0qKmpqawCwAAiE+OwkhjY6NaW1uVnp4e1p6eni6fz9fuNq+++qqWL1+uZcuWRXw/lZWVSk1NDV0yMzOdDBMAAMSQHv02zYEDB3TllVdq2bJlSktLi3i7srIyBQKB0GXPnj09OEoAAGBTPyed09LSlJSUJL/fH9bu9/uVkZHRpv+OHTu0a9cuXXjhhaG2YDD49R3366dt27ZpzJgxbbZzuVxyuVxOhgYAAGKUo5WR5ORkTZ8+XXV1daG2YDCouro65efnt+k/fvx4bdq0SRs3bgxdLrroIv3whz/Uxo0b+fgFAAA4WxmRJK/Xq+LiYuXk5Cg3N1fV1dVqbm5WSUmJJGnu3LkaOXKkKisrlZKSookTJ4ZtP2TIEElq0w4AAI5PjsNIUVGR9u3bp/Lycvl8PmVnZ6u2tjZ0Uuvu3buVmMgPuwIAgMgkGGOM7UEcS1NTk1JTUxUIBOR2u6O676zS56O6PwAAYs2uqsIe2W+kx2+WMAAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVV0KIzU1NcrKylJKSory8vK0bt26Dvs+9dRTysnJ0ZAhQ3TCCScoOztbjz76aJcHDAAA4ovjMLJ69Wp5vV5VVFRo/fr1mjJligoKCrR37952+w8bNkx33HGH6uvr9e6776qkpEQlJSV68cUXuz14AAAQ+xKMMcbJBnl5eTrzzDP14IMPSpKCwaAyMzN1ww03qLS0NKJ9TJs2TYWFhVq4cGFE/ZuampSamqpAICC32+1kuMeUVfp8VPcHAECs2VVV2CP7jfT47Whl5PDh
w2poaJDH4/lmB4mJ8ng8qq+vP+b2xhjV1dVp27ZtOuecczrs19LSoqamprALAACIT47CSGNjo1pbW5Wenh7Wnp6eLp/P1+F2gUBAgwYNUnJysgoLC/XAAw/oxz/+cYf9KysrlZqaGrpkZmY6GSYAAIghvfJtmsGDB2vjxo166623dM8998jr9Wrt2rUd9i8rK1MgEAhd9uzZ0xvDBAAAFvRz0jktLU1JSUny+/1h7X6/XxkZGR1ul5iYqLFjx0qSsrOztXXrVlVWVurcc89tt7/L5ZLL5XIyNAAAEKMcrYwkJydr+vTpqqurC7UFg0HV1dUpPz8/4v0Eg0G1tLQ4uWsAABCnHK2MSJLX61VxcbFycnKUm5ur6upqNTc3q6SkRJI0d+5cjRw5UpWVlZK+Pv8jJydHY8aMUUtLi9asWaNHH31UDz30UHQrAQAAMclxGCkqKtK+fftUXl4un8+n7Oxs1dbWhk5q3b17txITv1lwaW5u1rXXXquPPvpIAwYM0Pjx4/W3v/1NRUVF0asCAADELMe/M2IDvzMCAEDPianfGQEAAIg2wggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwKouhZGamhplZWUpJSVFeXl5WrduXYd9ly1bprPPPltDhw7V0KFD5fF4Ou0PAACOL47DyOrVq+X1elVRUaH169drypQpKigo0N69e9vtv3btWs2ZM0evvPKK6uvrlZmZqfPPP18ff/xxtwcPAABiX4IxxjjZIC8vT2eeeaYefPBBSVIwGFRmZqZuuOEGlZaWHnP71tZWDR06VA8++KDmzp0b0X02NTUpNTVVgUBAbrfbyXCPKav0+ajuDwCAWLOrqrBH9hvp8dvRysjhw4fV0NAgj8fzzQ4SE+XxeFRfXx/RPg4dOqSvvvpKw4YN67BPS0uLmpqawi4AACA+OQojjY2Nam1tVXp6elh7enq6fD5fRPu47bbbNGLEiLBA812VlZVKTU0NXTIzM50MEwAAxJBe/TZNVVWVVq1apaefflopKSkd9isrK1MgEAhd9uzZ04ujBAAAvamfk85paWlKSkqS3+8Pa/f7/crIyOh028WLF6uqqkovv/yyJk+e3Glfl8sll8vlZGgAACBGOVoZSU5O1vTp01VXVxdqCwaDqqurU35+fofb3XfffVq4cKFqa2uVk5PT9dECAIC442hlRJK8Xq+Ki4uVk5Oj3NxcVVdXq7m5WSUlJZKkuXPnauTIkaqsrJQk3XvvvSovL9djjz2mrKys0LklgwYN0qBBg6JYCgAAiEWOw0hRUZH27dun8vJy+Xw+ZWdnq7a2NnRS6+7du5WY+M2Cy0MPPaTDhw/rZz/7Wdh+KioqdPfdd3dv9AAAIOY5/p0RG/idEQAAek5M/c4IAABAtBFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhF
GAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFWEEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWdSmM1NTUKCsrSykpKcrLy9O6des67Lt582ZddtllysrKUkJCgqqrq7s6VgAAEIcch5HVq1fL6/WqoqJC69ev15QpU1RQUKC9e/e22//QoUMaPXq0qqqqlJGR0e0BAwCA+OI4jCxZskTz5s1TSUmJJkyYoKVLl2rgwIFasWJFu/3PPPNMLVq0SLNnz5bL5er2gAEAQHxxFEYOHz6shoYGeTyeb3aQmCiPx6P6+vqoDaqlpUVNTU1hFwAAEJ8chZHGxka1trYqPT09rD09PV0+ny9qg6qsrFRqamrokpmZGbV9AwCAvqVPfpumrKxMgUAgdNmzZ4/tIQEAgB7Sz0nntLQ0JSUlye/3h7X7/f6onpzqcrk4vwQAgOOEo5WR5ORkTZ8+XXV1daG2YDCouro65efnR31wAAAg/jlaGZEkr9er4uJi5eTkKDc3V9XV1WpublZJSYkkae7cuRo5cqQqKyslfX3S65YtW0L//vjjj7Vx40YNGjRIY8eOjWIpAAAgFjkOI0VFRdq3b5/Ky8vl8/mUnZ2t2tra0Emtu3fvVmLiNwsun3zyiaZOnRq6vnjxYi1evFgzZ87U2rVru18BAACIaQnGGGN7EMfS1NSk1NRUBQIBud3uqO47q/T5qO4PAIBYs6uqsEf2G+nxu09+mwYAABw/CCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKsIIwAAwCrCCAAAsIowAgAArCKMAAAAqwgjAADAKsIIAACwijACAACsIowAAACrCCMAAMAqwggAALCKMAIAAKwijAAAAKu6FEZqamqUlZWllJQU5eXlad26dZ32f+KJJzR+/HilpKRo0qRJWrNmTZcGCwAA4o/jMLJ69Wp5vV5VVFRo/fr1mjJligoKCrR37952+7/++uuaM2eOrr76am3YsEGXXHKJLrnkEr333nvdHjwAAIh9CcYY42SDvLw8nXnmmXrwwQclScFgUJmZmbrhhhtUWlrapn9RUZGam5v1z3/+M9R21llnKTs7W0uXLo3oPpuampSamqpAICC32+1kuMeUVfp8VPcHAECs2VVV2CP7jfT43c/JTg8fPqyGhgaVlZWF2hITE+XxeFRfX9/uNvX19fJ6vWFtBQUFeuaZZzq8n5aWFrW0tISuBwIBSV8XFW3BlkNR3ycAALGkJ46v397vsdY9HIWRxsZGtba2Kj09Paw9PT1d77//frvb+Hy+dvv7fL4O76eyslLz589v056ZmelkuAAAIAKp1T27/wMHDig1NbXD2x2Fkd5SVlYWtpoSDAa1f/9+fe9731NCQkLU7qepqUmZmZnas2dP1D/+6SvivUbqi33xXiP1xb54r7En6zPG6MCBAxoxYkSn/RyFkbS0NCUlJcnv94e1+/1+ZWRktLtNRkaGo/6S5HK55HK5wtqGDBniZKiOuN3uuHyCfVu810h9sS/ea6S+2BfvNfZUfZ2tiBzl6Ns0ycnJmj59uurq6kJtwWBQdXV1ys/Pb3eb/Pz8sP6S9K9//avD/gAA4Pji+GMar9er4uJi5eTkKDc3V9XV1WpublZJSYkkae7cuRo5
cqQqKyslSTfddJNmzpypP/3pTyosLNSqVav09ttv6y9/+Ut0KwEAADHJcRgpKirSvn37VF5eLp/Pp+zsbNXW1oZOUt29e7cSE79ZcJkxY4Yee+wx3Xnnnbr99tt12mmn6ZlnntHEiROjV0UXuVwuVVRUtPlIKJ7Ee43UF/vivUbqi33xXmNfqM/x74wAAABEE3+bBgAAWEUYAQAAVhFGAACAVYQRAABgVdyHkXvuuUczZszQwIEDI/7hNGOMysvLddJJJ2nAgAHyeDz64IMPwvrs379fV1xxhdxut4YMGaKrr75aBw8e7IEKOud0HLt27VJCQkK7lyeeeCLUr73bV61a1RslhenK43zuuee2Gfuvf/3rsD67d+9WYWGhBg4cqOHDh+vWW2/VkSNHerKUDjmtcf/+/brhhhs0btw4DRgwQKeccopuvPHG0N9wOsrWHNbU1CgrK0spKSnKy8vTunXrOu3/xBNPaPz48UpJSdGkSZO0Zs2asNsjeT32Nic1Llu2TGeffbaGDh2qoUOHyuPxtOl/1VVXtZmrWbNm9XQZHXJS3yOPPNJm7CkpKWF9+tocOqmvvfeThIQEFRZ+84fl+tL8/fvf/9aFF16oESNGKCEhodO/A3fU2rVrNW3aNLlcLo0dO1aPPPJImz5OX9eOmThXXl5ulixZYrxer0lNTY1om6qqKpOammqeeeYZ884775iLLrrInHrqqeaLL74I9Zk1a5aZMmWKeeONN8x//vMfM3bsWDNnzpweqqJjTsdx5MgR87///S/sMn/+fDNo0CBz4MCBUD9JZuXKlWH9vl1/b+nK4zxz5kwzb968sLEHAoHQ7UeOHDETJ040Ho/HbNiwwaxZs8akpaWZsrKyni6nXU5r3LRpk7n00kvNc889Z7Zv327q6urMaaedZi677LKwfjbmcNWqVSY5OdmsWLHCbN682cybN88MGTLE+P3+dvu/9tprJikpydx3331my5Yt5s477zT9+/c3mzZtCvWJ5PXYm5zWePnll5uamhqzYcMGs3XrVnPVVVeZ1NRU89FHH4X6FBcXm1mzZoXN1f79+3urpDBO61u5cqVxu91hY/f5fGF9+tIcOq3v008/DavtvffeM0lJSWblypWhPn1p/tasWWPuuOMO89RTTxlJ5umnn+60/3//+18zcOBA4/V6zZYtW8wDDzxgkpKSTG1tbaiP08esK+I+jBy1cuXKiMJIMBg0GRkZZtGiRaG2zz//3LhcLvOPf/zDGGPMli1bjCTz1ltvhfq88MILJiEhwXz88cdRH3tHojWO7Oxs84tf/CKsLZIncU/ran0zZ840N910U4e3r1mzxiQmJoa9YT700EPG7XablpaWqIw9UtGaw8cff9wkJyebr776KtRmYw5zc3PNddddF7re2tpqRowYYSorK9vt//Of/9wUFhaGteXl5Zlf/epXxpjIXo+9zWmN33XkyBEzePBg89e//jXUVlxcbC6++OJoD7VLnNZ3rPfWvjaH3Z2/+++/3wwePNgcPHgw1NaX5u/bInkP+P3vf2/OOOOMsLaioiJTUFAQut7dxywScf8xjVM7d+6Uz+eTx+MJtaWmpiovL0/19fWSpPr6eg0ZMkQ5OTmhPh6PR4mJiXrzzTd7bazRGEdDQ4M2btyoq6++us1t1113ndLS0pSbm6sVK1Yc809AR1t36vv73/+utLQ0TZw4UWVlZTp06FDYfidNmhT216QLCgrU1NSkzZs3R7+QTkTruRQIBOR2u9WvX/jvGPbmHB4+fFgNDQ1hr53ExER5PJ7Qa+e76uvrw/pLX8/F0f6RvB57U1dq/K5Dhw7pq6++0rBhw8La165dq+HDh2vcuHH6zW9+o08//TSqY49EV+s7ePCgRo0apczMTF188cVhr6O+NIfRmL/ly5dr9uzZOuGEE8La+8L8dcWxXoPReMwi0Sf/aq9NPp9PksIOVEevH73N5/Np+PDhYbf369dPw4YNC/XpDdEYx/Ll
y3X66adrxowZYe0LFizQj370Iw0cOFAvvfSSrr32Wh08eFA33nhj1MZ/LF2t7/LLL9eoUaM0YsQIvfvuu7rtttu0bds2PfXUU6H9tje/R2/rTdGYw8bGRi1cuFDXXHNNWHtvz2FjY6NaW1vbfWzff//9drfpaC6+/Vo72tZRn97UlRq/67bbbtOIESPC3txnzZqlSy+9VKeeeqp27Nih22+/XRdccIHq6+uVlJQU1Ro605X6xo0bpxUrVmjy5MkKBAJavHixZsyYoc2bN+vkk0/uU3PY3flbt26d3nvvPS1fvjysva/MX1d09BpsamrSF198oc8++6zbz/lIxGQYKS0t1b333ttpn61bt2r8+PG9NKLoirS+7vriiy/02GOP6a677mpz27fbpk6dqubmZi1atCgqB7Keru/bB+VJkybppJNO0nnnnacdO3ZozJgxXd6vE701h01NTSosLNSECRN09913h93Wk3OIrqmqqtKqVau0du3asJM8Z8+eHfr3pEmTNHnyZI0ZM0Zr167VeeedZ2OoEcvPzw/7w6czZszQ6aefrocfflgLFy60OLLoW758uSZNmqTc3Nyw9liev74iJsPIzTffrKuuuqrTPqNHj+7SvjMyMiRJfr9fJ510Uqjd7/crOzs71Gfv3r1h2x05ckT79+8Pbd8dkdbX3XE8+eSTOnTokObOnXvMvnl5eVq4cKFaWlq6/fcLequ+o/Ly8iRJ27dv15gxY5SRkdHmTHC/3y9JUZk/qXdqPHDggGbNmqXBgwfr6aefVv/+/TvtH805bE9aWpqSkpJCj+VRfr+/w1oyMjI67R/J67E3daXGoxYvXqyqqiq9/PLLmjx5cqd9R48erbS0NG3fvr1XD2bdqe+o/v37a+rUqdq+fbukvjWH3amvublZq1at0oIFC455P7bmrys6eg263W4NGDBASUlJ3X5ORCRqZ5/0cU5PYF28eHGoLRAItHsC69tvvx3q8+KLL1o7gbWr45g5c2abb2B05A9/+IMZOnRol8faFdF6nF999VUjybzzzjvGmG9OYP32meAPP/ywcbvd5ssvv4xeARHoao2BQMCcddZZZubMmaa5uTmi++qNOczNzTXXX3996Hpra6sZOXJkpyew/vSnPw1ry8/Pb3MCa2evx97mtEZjjLn33nuN2+029fX1Ed3Hnj17TEJCgnn22We7PV6nulLftx05csSMGzfO/O53vzPG9L057Gp9K1euNC6XyzQ2Nh7zPmzO37cpwhNYJ06cGNY2Z86cNiewduc5EdFYo7anPurDDz80GzZsCH19dcOGDWbDhg1hX2MdN26ceeqpp0LXq6qqzJAhQ8yzzz5r3n33XXPxxRe3+9XeqVOnmjfffNO8+uqr5rTTTrP21d7OxvHRRx+ZcePGmTfffDNsuw8++MAkJCSYF154oc0+n3vuObNs2TKzadMm88EHH5g///nPZuDAgaa8vLzH6/kup/Vt377dLFiwwLz99ttm586d5tlnnzWjR48255xzTmibo1/tPf/8883GjRtNbW2tOfHEE61+tddJjYFAwOTl5ZlJkyaZ7du3h32d8MiRI8YYe3O4atUq43K5zCOPPGK2bNlirrnmGjNkyJDQN5euvPJKU1paGur/2muvmX79+pnFixebrVu3moqKina/2nus12NvclpjVVWVSU5ONk8++WTYXB19Dzpw4IC55ZZbTH19vdm5c6d5+eWXzbRp08xpp53W6+G4K/XNnz/fvPjii2bHjh2moaHBzJ4926SkpJjNmzeH+vSlOXRa31E/+MEPTFFRUZv2vjZ/Bw4cCB3nJJklS5aYDRs2mA8//NAYY0xpaam58sorQ/2PfrX31ltvNVu3bjU1NTXtfrW3s8csGuI+jBQXFxtJbS6vvPJKqI/+//cYjgoGg+auu+4y6enpxuVymfPOO89s27YtbL+ffvqpmTNnjhk0aJBxu92mpKQkLOD0lmONY+fOnW3qNcaYsrIyk5mZ
aVpbW9vs84UXXjDZ2dlm0KBB5oQTTjBTpkwxS5cubbdvT3Na3+7du80555xjhg0bZlwulxk7dqy59dZbw35nxBhjdu3aZS644AIzYMAAk5aWZm6++eawr8X2Jqc1vvLKK+0+pyWZnTt3GmPszuEDDzxgTjnlFJOcnGxyc3PNG2+8Ebpt5syZpri4OKz/448/br7//e+b5ORkc8YZZ5jnn38+7PZIXo+9zUmNo0aNaneuKioqjDHGHDp0yJx//vnmxBNPNP379zejRo0y8+bNi+obvVNO6vvtb38b6puenm5+8pOfmPXr14ftr6/NodPn6Pvvv28kmZdeeqnNvvra/HX0/nC0puLiYjNz5sw222RnZ5vk5GQzevTosOPhUZ09ZtGQYEwvf18TAADgW/idEQAAYBVhBAAAWEUYAQAAVhFGAACAVYQRAABgFWEEAABYRRgBAABWEUYAAIBVhBEAAGAVYQQAAFhFGAEAAFYRRgAAgFX/B8EJ4nj3urA7AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Draw one U(-1, 1) sample per Gaussian sample so both columns have the same length\n",
    "# and can be concatenated into a dataset. Sampling with shape (N, 1) yields a column\n",
    "# vector directly, matching the layout of normal_samples.\n",
    "uniform = Uniform(-1.0, 1.0)\n",
    "uniform_samples = uniform.sample((normal_samples.shape[0], 1))\n",
    "\n",
    "# Histogram of the samples; the U(-1, 1) density is the constant 1 / (1 - (-1)) = 0.5.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(uniform_samples.squeeze(1).numpy(), bins=100, density=True, label=\"samples\")\n",
    "ax.hlines(0.5, -1.0, 1.0, colors=\"tab:orange\", label=\"U(-1, 1) pdf\")\n",
    "ax.set(title=\"Samples from U(-1, 1)\", xlabel=\"y\", ylabel=\"density\")\n",
    "ax.legend()\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([100000, 2])\n"
     ]
    }
   ],
   "source": [
    "# Build the training set: an (N, 2) tensor whose column 0 is the conditioning\n",
    "# variable x ~ N(0, 1) and whose column 1 is the target y ~ U(-1, 1).\n",
    "# Guard the layout: if both tensors were 1-D, torch.cat(dim=-1) would silently\n",
    "# produce a flat (2N,) tensor and the [1, 1] split during training would fail.\n",
    "assert normal_samples.ndim == 2 and normal_samples.shape[1] == 1, normal_samples.shape\n",
    "assert uniform_samples.ndim == 2 and uniform_samples.shape[1] == 1, uniform_samples.shape\n",
    "assert normal_samples.shape[0] == uniform_samples.shape[0], \"sample counts differ\"\n",
    "\n",
    "dataset = torch.cat([normal_samples, uniform_samples], dim=-1)\n",
    "trainloader = data.DataLoader(dataset, batch_size=64, shuffle=True)\n",
    "print(dataset.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "216645\n",
      "NormalizingFlow(\n",
      "  (transform): ComposedTransform(\n",
      "    (0): DependentTransform(MonotonicRQSTransform(bins=8), 1)\n",
      "    (1): DependentTransform(MonotonicRQSTransform(bins=8), 1)\n",
      "    (2): DependentTransform(MonotonicRQSTransform(bins=8), 1)\n",
      "  )\n",
      "  (base): DiagNormal(loc: tensor([[0.]]), scale: tensor([[1.]]))\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Conditional neural spline flow for p(y | x): y is 1-D (features) and x is 1-D (context).\n",
    "flow = zuko.flows.NSF(\n",
    "    features=1,\n",
    "    context=1,\n",
    "    transforms=3,\n",
    "    hidden_features=(256, 256),\n",
    ")\n",
    "\n",
    "n_params = sum(param.numel() for param in flow.parameters())\n",
    "print(n_params)\n",
    "\n",
    "# Calling the flow with a context tensor yields the conditional distribution over y.\n",
    "print(flow(torch.randn(1, 1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0) 0.7525021433830261 ± 0.0751950815320015\n",
      "(1) 0.7175981998443604 ± 0.025762757286429405\n",
      "(2) 0.7160289883613586 ± 0.024040669202804565\n",
      "(3) 0.7120639085769653 ± 0.022832900285720825\n",
      "(4) 0.7107819318771362 ± 0.021069666370749474\n",
      "(5) 0.7106785178184509 ± 0.02300257794559002\n",
      "(6) 0.7086082696914673 ± 0.02078089490532875\n",
      "(7) 0.7073511481285095 ± 0.018751656636595726\n"
     ]
    }
   ],
   "source": [
    "# Maximum-likelihood training of p(y | x): minimise the mean negative log-likelihood\n",
    "# of y under the flow conditioned on x. y ~ U(-1, 1) is sampled independently of x,\n",
    "# so the best achievable expected loss is the entropy of U(-1, 1): log(2) ≈ 0.693 nats.\n",
    "n_epochs = 8\n",
    "learning_rate = 1e-3\n",
    "\n",
    "optimizer = torch.optim.Adam(flow.parameters(), lr=learning_rate)\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    losses = []\n",
    "\n",
    "    for batch in trainloader:\n",
    "        # batch has shape (B, 2): column 0 is the context x, column 1 the target y.\n",
    "        x, y = batch.split([1, 1], dim=-1)\n",
    "        loss = -flow(x).log_prob(y).mean()\n",
    "\n",
    "        # Zero before backward so stale gradients left on flow's parameters (e.g. by an\n",
    "        # interrupted previous run of this cell) are never folded into the first step.\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        losses.append(loss.detach())\n",
    "\n",
    "    losses = torch.stack(losses)\n",
    "\n",
    "    print(f\"({epoch})\", losses.mean().item(), \"±\", losses.std().item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1000000, 1])\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAiMAAAGdCAYAAADAAnMpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAak0lEQVR4nO3df2xd913/8ZedErtZYrddNJtm3txfIlRj8YhjN0WjnWRmpAgI4keYJpJZJRIiqzpZiCUDxXQVcthCidRGTakoQy1VIibWCloFKosOoRmlJIvYCqlUpDRZgh0HmB1cyZ7s+/1jm/v1krS+aZKPfzwe0pV2j8+59311l/qpzz33uKZSqVQCAFBIbekBAIClTYwAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRN5QeYC6mp6dz9uzZrFq1KjU1NaXHAQDmoFKp5MKFC7n11ltTW3v59Y8FESNnz55NS0tL6TEAgCtw+vTpfPCDH7zszxdEjKxatSrJ919MQ0ND4WkAgLkYGxtLS0vLzO/xy1kQMfLDj2YaGhrECAAsMO92ioUTWAGAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARd1QegBgcWvd+eKs+yf3bCo0CTBfWRkBAIoSIwBAUWIEACjKOSPAdfWj55AkziOBpc7KCABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoX+0FFqRLfUX4R/nKMCwMYgQozt+vgaXNxzQAQFFiBAAoSowAAEWJEQCgKDECABTl2zTAVTOXr9vOx8cGyhIjwBUTCMDV4GMaAKAoKyMA85yLwrHYiRFg0brUx0h+kcP8I0aAece5KLC0OGcEAChKjAAARV1RjOzfvz+tra2pr69PZ2dnjhw5ctl9v/KVr6SmpmbWrb6+/ooHBgAWl6pj5NChQ+nt7U1fX1+OHTuWdevWpbu7O+fOnbvsMQ0NDfmv//qvmdubb775noYGABaPqmPk0Ucfzfbt29PT05O77747Bw4cyIoVK/L0009f9piampo0NzfP3Jqamt7T0ADA4lFVjExOTubo0aPp6up6+wFqa9PV1ZXBwcHLHvd///d/+fCHP5yWlpb80i/9Ul577bUrnxgAWFSqipHz589namrqopWNpqamDA0NXfKYn/iJn8jTTz+dF154Ic8++2ymp6dz77335jvf+c5ln2diYiJjY2OzbgDA4nTNrzOycePGbNy4ceb+vffem5/8yZ/Mk08+mUceeeSSx/T39+fhhx++1qMBS5CrmcL8U9XKyOrVq7Ns2bIMDw/P2j48PJzm5uY5PcaP/diP5WMf+1jeeOONy+6za9eujI6OztxOnz5dzZgAwAJSVYwsX74869evz8DAwMy26enpDAwMzFr9eCdTU1P51re+lR//8R+/7D51dXVpaGiYdQMAFqeqP6bp7e3Ntm3b0t7eno6Ojuzbty/j4+Pp6elJkmzdujVr1qxJf39/kuSLX/xi7rnnntx555357ne/my9/+ct5880381u/9VtX95UAAAtS1TGyZcuWjIyMZPfu3RkaGkpbW1sOHz48c1LrqVOnUlv79oLL//7v/2b79u0ZGhrKzTffnPXr1+cb3/hG7r777qv3KgCABaumUqlUSg/xbsbGxtLY2JjR0VEf2cA8shj+oN1COIHVSbcsVHP9/e1v0wAARV3zr/YCcHlWPcDKCABQmBgBAIryMQ2wpF3qJFwflcD1ZWUEACjKygjAPLIYvi4N1bIyAgAUJUYAgKLECABQlBgBAIpyAivAj3BVVLi+xAjAAuPaKCw2PqYBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICiXGcE4F3M5S/pus4HXDkrIwBAUVZGAK4CV0WFK2dlBAAoysoIwDXiD+7B3IgRgOtkLifCwlLkYxoAoCgxAgAUJUYA
gKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQ1BXFyP79+9Pa2pr6+vp0dnbmyJEjczru4MGDqampyebNm6/kaQGARajqGDl06FB6e3vT19eXY8eOZd26denu7s65c+fe8biTJ0/md3/3d/Pxj3/8iocFABafqmPk0Ucfzfbt29PT05O77747Bw4cyIoVK/L0009f9pipqal8+tOfzsMPP5zbb7/9PQ0MACwuVcXI5ORkjh49mq6urrcfoLY2XV1dGRwcvOxxX/ziF/OBD3wgDzzwwJyeZ2JiImNjY7NuAMDiVFWMnD9/PlNTU2lqapq1vampKUNDQ5c85p//+Z/z53/+53nqqafm/Dz9/f1pbGycubW0tFQzJgCwgFzTb9NcuHAhv/mbv5mnnnoqq1evnvNxu3btyujo6Mzt9OnT13BKAKCkG6rZefXq1Vm2bFmGh4dnbR8eHk5zc/NF+//nf/5nTp48mV/4hV+Y2TY9Pf39J77hhrz++uu54447Ljqurq4udXV11YwGACxQVa2MLF++POvXr8/AwMDMtunp6QwMDGTjxo0X7b927dp861vfyvHjx2duv/iLv5hPfOITOX78uI9fAIDqVkaSpLe3N9u2bUt7e3s6Ojqyb9++jI+Pp6enJ0mydevWrFmzJv39/amvr89HPvKRWcffdNNNSXLRdgBgaao6RrZs2ZKRkZHs3r07Q0NDaWtry+HDh2dOaj116lRqa13YFQCYm5pKpVIpPcS7GRsbS2NjY0ZHR9PQ0FB6HOAHWne+WHoEfuDknk2lR4CLzPX3tyUMAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKuqIY2b9/f1pbW1NfX5/Ozs4cOXLksvv+zd/8Tdrb23PTTTflfe97X9ra2vLMM89c8cAAwOJSdYwcOnQovb296evry7Fjx7Ju3bp0d3fn3Llzl9z/lltuye///u9ncHAw//Zv/5aenp709PTk7//+79/z8ADAwldTqVQq1RzQ2dmZDRs25PHHH0+STE9Pp6WlJQ8++GB27tw5p8f46Z/+6WzatCmPPPLInPYfGxtLY2NjRkdH09DQUM24wDXUuvPF0iPwAyf3bCo9Alxkrr+/q1oZmZyczNGjR9PV1fX2A9TWpqurK4ODg+96fKVSycDAQF5//fX87M/+7GX3m5iYyNjY2KwbALA4VRUj58+fz9TUVJqammZtb2pqytDQ0GWPGx0dzcqVK7N8+fJs2rQpjz32WH7u537usvv39/ensbFx5tbS0lLNmADAAnJdvk2zatWqHD9+PK+++mr+6I/+KL29vXnllVcuu/+uXbsyOjo6czt9+vT1GBMAKOCGanZevXp1li1bluHh4Vnbh4eH09zcfNnjamtrc+eddyZJ2tra8h//8R/p7+/P/ffff8n96+rqUldXV81oAMACVdXKyPLly7N+/foMDAzMbJuens7AwEA2btw458eZnp7OxMRENU8NACxSVa2MJElvb2+2bduW9vb2dHR0ZN++fRkfH09PT0+SZOvWrVmzZk36+/uTfP/8j/b29txxxx2ZmJjISy+9lGeeeSZPPPHE1X0lAMCCVHWM
bNmyJSMjI9m9e3eGhobS1taWw4cPz5zUeurUqdTWvr3gMj4+nt/5nd/Jd77zndx4441Zu3Ztnn322WzZsuXqvQoAYMGq+jojJbjOCMxPrjMyf7jOCPPRNbnOCADA1SZGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFXFCP79+9Pa2tr6uvr09nZmSNHjlx236eeeiof//jHc/PNN+fmm29OV1fXO+4PACwtVcfIoUOH0tvbm76+vhw7dizr1q1Ld3d3zp07d8n9X3nllXzqU5/KP/7jP2ZwcDAtLS355Cc/mTNnzrzn4QGAha+mUqlUqjmgs7MzGzZsyOOPP54kmZ6eTktLSx588MHs3LnzXY+fmprKzTffnMcffzxbt26d03OOjY2lsbExo6OjaWhoqGZc4Bpq3fli6RH4gZN7NpUeAS4y19/fVa2MTE5O5ujRo+nq6nr7AWpr09XVlcHBwTk9xltvvZXvfe97ueWWW6p5agBgkbqhmp3Pnz+fqampNDU1zdre1NSUEydOzOkxPv/5z+fWW2+dFTQ/amJiIhMTEzP3x8bGqhkTAFhAruu3afbs2ZODBw/ma1/7Wurr6y+7X39/fxobG2duLS0t13FKAOB6qipGVq9enWXLlmV4eHjW9uHh4TQ3N7/jsXv37s2ePXvyD//wD/noRz/6jvvu2rUro6OjM7fTp09XMyYAsIBUFSPLly/P+vXrMzAwMLNteno6AwMD2bhx42WP+9KXvpRHHnkkhw8fTnt7+7s+T11dXRoaGmbdAIDFqapzRpKkt7c327ZtS3t7ezo6OrJv376Mj4+np6cnSbJ169asWbMm/f39SZI//uM/zu7du/Pcc8+ltbU1Q0NDSZKVK1dm5cqVV/GlAAALUdUxsmXLloyMjGT37t0ZGhpKW1tbDh8+PHNS66lTp1Jb+/aCyxNPPJHJycn86q/+6qzH6evryx/+4R++t+kBgAWv6uuMlOA6IzA/uc7I/OE6I8xH1+Q6IwAAV5sYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgqCuKkf3796e1tTX19fXp7OzMkSNHLrvva6+9ll/5lV9Ja2trampqsm/fviudFQBYhKqOkUOHDqW3tzd9fX05duxY1q1bl+7u7pw7d+6S+7/11lu5/fbbs2fPnjQ3N7/ngQGAxaXqGHn00Uezffv29PT05O67786BAweyYsWKPP3005fcf8OGDfnyl7+c3/iN30hdXd17HhgAWFyqipHJyckcPXo0XV1dbz9AbW26uroyODh41YaamJjI2NjYrBsAsDhVFSPnz5/P1NRUmpqaZm1vamrK0NDQVRuqv78/jY2NM7eWlpar9tgAwPwyL79Ns2vXroyOjs7cTp8+XXokAOAauaGanVevXp1ly5ZleHh41vbh4eGrenJqXV2d
80sAYImoamVk+fLlWb9+fQYGBma2TU9PZ2BgIBs3brzqwwEAi19VKyNJ0tvbm23btqW9vT0dHR3Zt29fxsfH09PTkyTZunVr1qxZk/7+/iTfP+n13//932f+95kzZ3L8+PGsXLkyd95551V8KQDAQlR1jGzZsiUjIyPZvXt3hoaG0tbWlsOHD8+c1Hrq1KnU1r694HL27Nl87GMfm7m/d+/e7N27N/fdd19eeeWV9/4KAIAFraZSqVRKD/FuxsbG0tjYmNHR0TQ0NJQeB/iB1p0vlh6BHzi5Z1PpEeAic/39PS+/TQMALB1iBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEXdUHoAYGFo3fli6RGARcrKCABQlBgBAIoSIwBAUWIEAChKjAAARYkRAKAoMQIAFCVGAICixAgAUJQYAQCKEiMAQFFiBAAoSowAAEWJEQCgKDECABQlRgCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABF3VB6AADeu9adL866f3LPpkKTQPWsjAAARYkRAKAoMQIAFHVFMbJ///60tramvr4+nZ2dOXLkyDvu/9d//ddZu3Zt6uvr81M/9VN56aWXrmhYAGDxqTpGDh06lN7e3vT19eXYsWNZt25duru7c+7cuUvu/41vfCOf+tSn8sADD+Sb3/xmNm/enM2bN+fb3/72ex4eAFj4aiqVSqWaAzo7O7Nhw4Y8/vjjSZLp6em0tLTkwQcfzM6dOy/af8uWLRkfH8/f/d3fzWy755570tbWlgMHDszpOcfGxtLY2JjR0dE0NDRUMy5wlfzotzWY33ybhvlgrr+/q/pq7+TkZI4ePZpdu3bNbKutrU1XV1cGBwcveczg4GB6e3tnbevu7s7zzz9/2eeZmJjIxMTEzP3R0dEk339RQBnTE2+VHoEq+O8l88EP/3/4buseVcXI+fPnMzU1laamplnbm5qacuLEiUseMzQ0dMn9h4aGLvs8/f39efjhhy/a3tLSUs24AEtW477SE8DbLly4kMbGxsv+fF5e9GzXrl2zVlOmp6fzP//zP3n/+9+fmpqagpMtTGNjY2lpacnp06d9zDUPeD/mD+/F/OG9mF+u1vtRqVRy4cKF3Hrrre+4X1Uxsnr16ixbtizDw8Oztg8PD6e5ufmSxzQ3N1e1f5LU1dWlrq5u1rabbrqpmlG5hIaGBv/I5xHvx/zhvZg/vBfzy9V4P95pReSHqvo2zfLly7N+/foMDAzMbJuens7AwEA2btx4yWM2btw4a/8kefnlly+7PwCwtFT9MU1vb2+2bduW9vb2dHR0ZN++fRkfH09PT0+SZOvWrVmzZk36+/uTJA899FDuu+++/Mmf/Ek2bdqUgwcP5l//9V/zZ3/2Z1f3lQAAC1LVMbJly5aMjIxk9+7dGRoaSltbWw4fPjxzkuqpU6dSW/v2gsu9996b5557Ln/wB3+QL3zhC7nrrrvy/PPP5yMf+cjVexW8o7q6uvT19V300RdleD/mD+/F/OG9mF+u9/tR9XVGAACuJn+bBgAoSowAAEWJEQCgKDECABQlRpawiYmJtLW1paamJsePHy89zpJz8uTJPPDAA7ntttty44035o477khfX18mJydLj7Zk7N+/P62tramvr09nZ2eOHDlSeqQlp7+/Pxs2bMiqVavygQ98IJs3b87rr79eeiyS7NmzJzU1Nfnc5z53zZ9LjCxhv/d7v/eul+jl2jlx4kSmp6fz5JNP5rXXXsuf/umf5sCBA/nCF75QerQl4dChQ+nt7U1fX1+OHTuWdevWpbu7O+fOnSs92pLy9a9/PTt27Mi//Mu/5OWXX873vve9fPKTn8z4+Hjp0Za0V199NU8++WQ++tGPXp8nrLAk
vfTSS5W1a9dWXnvttUqSyje/+c3SI1GpVL70pS9VbrvtttJjLAkdHR2VHTt2zNyfmpqq3HrrrZX+/v6CU3Hu3LlKksrXv/710qMsWRcuXKjcddddlZdffrly3333VR566KFr/pxWRpag4eHhbN++Pc8880xWrFhRehz+P6Ojo7nllltKj7HoTU5O5ujRo+nq6prZVltbm66urgwODhacjNHR0STx76CgHTt2ZNOmTbP+fVxr8/Kv9nLtVCqVfOYzn8lv//Zvp729PSdPniw9Ej/wxhtv5LHHHsvevXtLj7LonT9/PlNTUzNXjv6hpqamnDhxotBUTE9P53Of+1x+5md+xlW6Czl48GCOHTuWV1999bo+r5WRRWLnzp2pqal5x9uJEyfy2GOP5cKFC9m1a1fpkRetub4X/78zZ87k53/+5/Nrv/Zr2b59e6HJoawdO3bk29/+dg4ePFh6lCXp9OnTeeihh/JXf/VXqa+vv67P7XLwi8TIyEj++7//+x33uf322/Prv/7r+du//dvU1NTMbJ+amsqyZcvy6U9/On/5l395rUdd9Ob6XixfvjxJcvbs2dx///2555578pWvfGXW33bi2picnMyKFSvy1a9+NZs3b57Zvm3btnz3u9/NCy+8UG64Jeqzn/1sXnjhhfzTP/1TbrvtttLjLEnPP/98fvmXfznLli2b2TY1NZWamprU1tZmYmJi1s+uJjGyxJw6dSpjY2Mz98+ePZvu7u589atfTWdnZz74wQ8WnG7pOXPmTD7xiU9k/fr1efbZZ6/ZP3Qu1tnZmY6Ojjz22GNJvv8RwYc+9KF89rOfzc6dOwtPt3RUKpU8+OCD+drXvpZXXnkld911V+mRlqwLFy7kzTffnLWtp6cna9euzec///lr+tGZc0aWmA996EOz7q9cuTJJcscddwiR6+zMmTO5//778+EPfzh79+7NyMjIzM+am5sLTrY09Pb2Ztu2bWlvb09HR0f27duX8fHx9PT0lB5tSdmxY0eee+65vPDCC1m1alWGhoaSJI2NjbnxxhsLT7e0rFq16qLgeN/73pf3v//91/wcHjEChbz88st544038sYbb1wUghYsr70tW7ZkZGQku3fvztDQUNra2nL48OGLTmrl2nriiSeSJPfff/+s7X/xF3+Rz3zmM9d/IIrwMQ0AUJQz5QCAosQIAFCUGAEAihIjAEBRYgQAKEqMAABFiREAoCgxAgAUJUYAgKLECABQlBgBAIoSIwBAUf8PYM8BUYV01AQAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Imported here so this cell runs on a fresh kernel; the top imports cell does not import matplotlib.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Draw samples from the flow conditioned on N_CONTEXTS random context vectors,\n",
    "# then plot the empirical density of those samples.\n",
    "# NOTE(review): assumes `flow` was built with a 1-dimensional context, as the (N, 1) input implies.\n",
    "N_CONTEXTS = 1_000_000\n",
    "contexts = torch.randn(N_CONTEXTS, 1)\n",
    "samples = flow(contexts).sample()\n",
    "print(samples.shape)\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "# .cpu() is a no-op for CPU tensors and avoids a failure if the flow lives on a GPU.\n",
    "ax.hist(samples.detach().cpu().numpy(), bins=100, density=True)\n",
    "ax.set(title=\"Flow samples under random contexts\", xlabel=\"sample value\", ylabel=\"density\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "app",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
